{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import logging\n",
    "from collections import defaultdict\n",
    "from typing import Dict, Tuple\n",
    "\n",
    "import lightning as L\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from omegaconf import DictConfig, OmegaConf\n",
    "from peft import LoraConfig, PeftModel, get_peft_model\n",
    "from rich import print\n",
    "from torch import Tensor\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel\n",
    "from transformers.models.clip.modeling_clip import CLIPVisionTransformer\n",
    "\n",
    "from fusion_bench.method import load_algorithm_from_config\n",
    "from fusion_bench.models.linearized.vision_model import (\n",
    "    load_l_lora_vision_model_hf,\n",
    "    load_lora_vision_model_hf,\n",
    ")\n",
    "from fusion_bench.taskpool import load_taskpool_from_config\n",
    "from fusion_bench.method.ties_merging.ties_merging_utils import state_dict_to_vector\n",
    "\n",
    "# Only surface warnings and errors from the libraries used below.\n",
    "logging.basicConfig(level=logging.WARNING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local copy of the CLIP ViT-B/16 checkpoint; replace this path with the hub id\n",
    "# `openai/clip-vit-base-patch16` if you do not have it on disk.\n",
    "BASE_MODEL = \"/data0/users/tanganke/data/huggingface_models/openai/clip-vit-base-patch16\"\n",
    "\n",
    "# Downstream tasks; the cells below iterate over them in this order.\n",
    "DATASET_NAMES = [\n",
    "    \"sun397\", \"stanford-cars\", \"resisc45\", \"eurosat\",\n",
    "    \"svhn\", \"gtsrb\", \"mnist\", \"dtd\",\n",
    "]\n",
    "\n",
    "\n",
    "def normalize(tensor: Tensor, dim: int = 0, eps: float = 1e-8) -> Tensor:\n",
    "    \"\"\"\n",
    "    Rescale a tensor to unit L2 norm along one dimension.\n",
    "\n",
    "    Args:\n",
    "        tensor (Tensor): Input tensor.\n",
    "        dim (int, optional): Dimension over which the norm is taken. Defaults to 0.\n",
    "        eps (float, optional): Lower bound applied to the norm before dividing, so that\n",
    "            all-zero slices come out as zeros instead of NaN. Defaults to 1e-8.\n",
    "\n",
    "    Returns:\n",
    "        Tensor: The rescaled tensor, with the same shape as `tensor`.\n",
    "    \"\"\"\n",
    "    norms = tensor.norm(dim=dim, keepdim=True)\n",
    "    safe_norms = norms.clamp(min=eps)\n",
    "    return tensor / safe_norms\n",
    "\n",
    "\n",
    "def compute_cos_similarity_as_df(\n",
    "    task_vectors_as_state_dicts: Dict[str, Dict[str, Tensor]],\n",
    "    dataset_names=None,\n",
    ") -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Compute the pairwise cosine similarity between task vectors.\n",
    "\n",
    "    Args:\n",
    "        task_vectors_as_state_dicts (Dict[str, Dict[str, Tensor]]): Maps each task name to the\n",
    "            state dict holding that task's vector.\n",
    "        dataset_names (optional): Task names to compare, in output order. Defaults to the\n",
    "            notebook-level `DATASET_NAMES`.\n",
    "\n",
    "    Returns:\n",
    "        pd.DataFrame: One row per ordered pair of tasks (self-pairs included) with the columns\n",
    "            `task:0`, `task:1` and `cosine_similarity`. Empty if no task names are given.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a requested task name is missing from `task_vectors_as_state_dicts`.\n",
    "    \"\"\"\n",
    "    columns = [\"task:0\", \"task:1\", \"cosine_similarity\"]\n",
    "    names = list(DATASET_NAMES if dataset_names is None else dataset_names)\n",
    "\n",
    "    # Fail early with every missing task listed, before any vector is flattened.\n",
    "    missing = [name for name in names if name not in task_vectors_as_state_dicts]\n",
    "    if missing:\n",
    "        raise KeyError(f\"missing task vectors for: {missing}\")\n",
    "    if not names:\n",
    "        # torch.stack rejects an empty list, so return an empty frame instead.\n",
    "        return pd.DataFrame(columns=columns)\n",
    "\n",
    "    # One row per task; each vector is cast to float64 before stacking.\n",
    "    task_vectors = torch.stack(\n",
    "        [\n",
    "            state_dict_to_vector(task_vectors_as_state_dicts[name]).double()\n",
    "            for name in names\n",
    "        ]\n",
    "    )\n",
    "    normalized_task_vectors = normalize(task_vectors, dim=1)\n",
    "\n",
    "    records = []\n",
    "    for task_0_idx, task_0 in tqdm(enumerate(names), total=len(names)):\n",
    "        for task_1_idx, task_1 in enumerate(names):\n",
    "            similarity = F.cosine_similarity(\n",
    "                normalized_task_vectors[task_0_idx],\n",
    "                normalized_task_vectors[task_1_idx],\n",
    "                dim=0,\n",
    "            ).item()\n",
    "            records.append(\n",
    "                {\"task:0\": task_0, \"task:1\": task_1, \"cosine_similarity\": similarity}\n",
    "            )\n",
    "    return pd.DataFrame(records, columns=columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Task arithmetic on lora models.\n",
    "\n",
    "$$ \\theta_{merged} = \\theta_{pretrained} + \\lambda \\sum_{i=1}^{n} \\left( \\theta_{i} - \\theta_{pretrained} \\right) $$\n",
    "\n",
    "Typically, there are two ways to do this:\n",
    "\n",
    "1. Perform the arithmetic operation on the adapters, as in AdapterSoup, LoraHub, and the L-LoRA paper. However, this requires saving the initialization state of the adapters, which is inconvenient for deployment.\n",
    "2. Perform the arithmetic operation in the original model weight space, i.e., after merging and unloading the adapters.\n",
    "\n",
    "I choose the second way here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task-arithmetic merging algorithm, used below to combine the pre-trained and fine-tuned models.\n",
    "# `scaling_factor` is presumably the lambda of the formula above -- confirm against fusion_bench.\n",
    "alg = load_algorithm_from_config(\n",
    "    DictConfig(dict(name=\"task_arithmetic\", scaling_factor=0.3))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following taskpool is used to evaluate the performance of the merged model. Refer to `config/taskpool/clip-vit-classification_TA8.yaml`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# YAML definition of the evaluation taskpool: the test split of each of the eight datasets.\n",
    "# The dataset paths are local to the author's machine; replace them with your own.\n",
    "# This is an f-string: `{BASE_MODEL}` is substituted into the `clip_model` entry.\n",
    "taskpool_config = f\"\"\"\n",
    "type: clip_vit_classification\n",
    "name: clip-vit-classification_TA8\n",
    "\n",
    "dataset_type: huggingface_image_classification\n",
    "tasks:\n",
    "  - name: sun397\n",
    "    dataset:\n",
    "      name: tanganke/sun397\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/tanganke/sun397\n",
    "      split: test\n",
    "  - name: stanford_cars\n",
    "    dataset:\n",
    "      name: tanganke/stanford_cars\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/tanganke/stanford_cars\n",
    "      split: test\n",
    "  - name: resisc45\n",
    "    dataset:\n",
    "      name: tanganke/resisc45\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/tanganke/resisc45\n",
    "      split: test\n",
    "  - name: eurosat\n",
    "    dataset:\n",
    "      name: tanganke/eurosat\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/tanganke/eurosat\n",
    "      split: test\n",
    "  - name: svhn\n",
    "    dataset:\n",
    "      type: instantiate\n",
    "      name: svhn\n",
    "      object: \n",
    "        _target_: datasets.load_dataset\n",
    "        _args_:\n",
    "          - /data0/users/tanganke/data/huggingface_datasets/ufldl-stanford/svhn\n",
    "          - cropped_digits\n",
    "        split: test\n",
    "  - name: gtsrb\n",
    "    dataset:\n",
    "      name: tanganke/gtsrb\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/tanganke/gtsrb\n",
    "      split: test\n",
    "  - name: mnist\n",
    "    dataset:\n",
    "      name: mnist\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/ylecun/mnist\n",
    "      split: test\n",
    "  - name: dtd\n",
    "    dataset:\n",
    "      name: tanganke/dtd\n",
    "      path: /data0/users/tanganke/data/huggingface_datasets/tanganke/dtd\n",
    "      split: test\n",
    "\n",
    "clip_model: {BASE_MODEL}\n",
    "batch_size: 128\n",
    "num_workers: 16\n",
    "fast_dev_run: false\n",
    "\n",
    "\"\"\"\n",
    "# Parse the YAML text through an in-memory file object, then build the taskpool from it.\n",
    "# Note that `taskpool_config` is rebound here from the raw string to the parsed config.\n",
    "taskpool_config = OmegaConf.load(io.StringIO(taskpool_config))\n",
    "taskpool = load_taskpool_from_config(taskpool_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LoRA models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "load the pretrained model and fine-tuned models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d720cb45f99d4404808317692abfce18",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Task arithmetic needs one pre-trained model plus one fine-tuned model per task.\n",
    "models = dict(_pretrained_=CLIPVisionModel.from_pretrained(BASE_MODEL).vision_model)\n",
    "\n",
    "# Load each task's LoRA adapter on top of the base model, then merge and unload it so\n",
    "# the arithmetic happens in the original weight space.\n",
    "for dataset_name in tqdm(DATASET_NAMES):\n",
    "    adapter_id = f\"tanganke/clip-vit-base-patch16_{dataset_name.replace('_','-')}_lora-16\"\n",
    "    models[dataset_name] = load_lora_vision_model_hf(\n",
    "        BASE_MODEL, adapter_id\n",
    "    ).merge_and_unload()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "merge them into a merged_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Profiler Report\n",
      "\n",
      "--------------------------------------------------------------------------------------------------------------------------\n",
      "|  Action       \t|  Mean duration (s)\t|  Num calls      \t|  Total time (s) \t|  Percentage %   \t|\n",
      "--------------------------------------------------------------------------------------------------------------------------\n",
      "|  Total        \t|  -              \t|  21             \t|  63.559         \t|  100 %          \t|\n",
      "--------------------------------------------------------------------------------------------------------------------------\n",
      "|  load model   \t|  0.03468        \t|  11             \t|  0.38148        \t|  0.6002         \t|\n",
      "|  merge weights\t|  0.024422       \t|  10             \t|  0.24422        \t|  0.38423        \t|\n",
      "--------------------------------------------------------------------------------------------------------------------------\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# The models must be homogeneous; here they are all `CLIPVisionTransformer`.\n",
    "# `models` holds the pre-trained model under `_pretrained_` plus one fine-tuned model per dataset.\n",
    "merged_model = alg.run(models)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3473eaaf27c14821828523341d313f04",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating tasks:   0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data0/users/tanganke/anaconda3/envs/fusionbench/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59285e4fbbf34cb3b7c40a681a82777e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d11463ce8237437392badd7f89702ccd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7021382d06e049cbb31121ceb13015a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0e7866e7ef14085915e31cc23614ed4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "979e71c786b949feabc2b9f4a3b5b808",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/156 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data0/users/tanganke/anaconda3/envs/fusionbench/lib/python3.12/site-packages/PIL/TiffImagePlugin.py:900: UserWarning: Truncated File Read\n",
      "  warnings.warn(str(msg))\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9596bb3ddd24aa3b1f7b39589219ea4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/63 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8920d9c0198e484c94f2739ee6fc82a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "626732ddb79446c68e077ab491139456",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/22 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9424ef5649ae44d6a0792f9878e78247",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/204 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8770fe8028994c9d8f60920379492b37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/99 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b7ae88675d94e8f9966a0e52d7a18d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/79 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a4b4607fc374e92a3dc15c3c4773186",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Evaluate the merged model on the taskpool; the report holds accuracy and loss per task.\n",
    "report = taskpool.evaluate(merged_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span>\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'model_info'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'trainable_params'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">85799424</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'all_params'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">85799424</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'trainable_percentage'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1.0</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'sun397'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6784383058547974</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1.1047296524047852</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'stanford_cars'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6798905730247498</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8926140069961548</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'resisc45'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.7342857122421265</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.794391930103302</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'eurosat'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.8292592763900757</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.5189533233642578</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'svhn'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.795712947845459</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.7038500308990479</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'gtsrb'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.6212984919548035</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1.2499028444290161</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'mnist'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9150999784469604</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.4461774230003357</span><span style=\"font-weight: bold\">}</span>,\n",
       "    <span style=\"color: #008000; text-decoration-color: #008000\">'dtd'</span>: <span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'accuracy'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.5234042406082153</span>, <span style=\"color: #008000; text-decoration-color: #008000\">'loss'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1.7356970310211182</span><span style=\"font-weight: bold\">}</span>\n",
       "<span style=\"font-weight: bold\">}</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m{\u001b[0m\n",
       "    \u001b[32m'model_info'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'trainable_params'\u001b[0m: \u001b[1;36m85799424\u001b[0m, \u001b[32m'all_params'\u001b[0m: \u001b[1;36m85799424\u001b[0m, \u001b[32m'trainable_percentage'\u001b[0m: \u001b[1;36m1.0\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'sun397'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.6784383058547974\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m1.1047296524047852\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'stanford_cars'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.6798905730247498\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.8926140069961548\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'resisc45'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.7342857122421265\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.794391930103302\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'eurosat'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.8292592763900757\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.5189533233642578\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'svhn'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.795712947845459\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.7038500308990479\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'gtsrb'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.6212984919548035\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m1.2499028444290161\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'mnist'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.9150999784469604\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.4461774230003357\u001b[0m\u001b[1m}\u001b[0m,\n",
       "    \u001b[32m'dtd'\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.5234042406082153\u001b[0m, \u001b[32m'loss'\u001b[0m: \u001b[1;36m1.7356970310211182\u001b[0m\u001b[1m}\u001b[0m\n",
       "\u001b[1m}\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# `print` is `rich.print` (imported in the first cell), not the builtin.\n",
    "print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fusionbench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
